import csv
import os
import re
import sys
import time
import math
import datetime as dt

# Backslash-escaped IPv4 addresses accepted in the address field of r1.
_R1_ADDRESSES = (
    r'172\.174\.255\.187',
    r'172\.178\.73\.102',
    r'4\.246\.128\.124',
    r'4\.246\.128\.151',
    r'4\.246\.129\.62',
    r'4\.246\.130\.139',
    r'52\.186\.182\.118',
)

# Matches "Transport,<digits>,<address>]" where <address> is one of the
# entries above; main() uses it to decide which log lines to keep.
r1 = re.compile(r'Transport,[0-9]*,(' + '|'.join(_R1_ADDRESSES) + r')]')

# Matches "BACKEND: <label>"; main() derives the backend label from it.
r2 = re.compile(r'BACKEND: [a-zA-Z0-9\-_\.]+')

# Matches "Transport,<digits>"; main() takes the session number from it.
r3 = re.compile(r'Transport,[0-9]+')

class RunningStdevAvg:
    """Single-pass (Welford-style) running mean and sample standard deviation.

    Values are fed one at a time through ``update``; nothing is stored
    except the count, the running mean and the running sum of squared
    deviations (``M2``).
    """

    def __init__(self, n=0):
        # TODO: SET n TO 1? COUNTING LOGIN AS 1 INPUT???
        self.n = n        # how many values have been seen so far
        self.mean = 0.0   # mean of the values seen so far
        self.M2 = 0.0     # accumulated squared deviation from the mean

    def update(self, new_value):
        """Fold ``new_value`` into the running statistics."""
        self.n += 1
        offset_before = new_value - self.mean
        self.mean += offset_before / self.n
        # The second factor uses the mean *after* it absorbed new_value.
        self.M2 += offset_before * (new_value - self.mean)

    def current_mean(self):
        """Return the mean of all values seen so far (0.0 before any update)."""
        return self.mean

    def current_stddev(self):
        """Return the sample standard deviation, or 0.0 with fewer than two values."""
        if self.n >= 2:
            return math.sqrt(self.M2 / (self.n - 1))
        # A single data point (or none) has no defined standard deviation.
        return 0.0


class LogData:
    """Accumulator for everything extracted from one honeypot session.

    Instances are created and filled by ``cowrie_log`` and later flattened
    into one CSV row by ``main``.  Every ``min_*`` field starts at a large
    positive sentinel and every ``max_*`` field at a large negative one, so
    the first observed value replaces them; sessions that never observe a
    value keep the sentinel.
    """

    def __init__(self):
        self.id = 'n/a'
        # ``name`` and ``date`` are assigned by cowrie_log() and read by
        # main(); they are initialised here so an instance is always complete.
        self.name = 'n/a'
        self.date = 'n/a'
        self.type = 'none'  # cowrie, llmv1, llmv2 (to be extended with more models)
        self.user = 'none'
        self.password = 'none'
        self.window_size = 'n/a'
        self.client_ver = "n/a"
        self.client_ver_hash = "n/a"
        self.pub_key = "ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff"

        # Session boundaries, as seconds since 1970-01-01.
        self.start_time = 0.0
        self.end_time = 0.0
        self.last_input_time = 0.0
        self.last_input_epoch = 0.0

        self.session_time = 0.0       # end_time - start_time
        self.session_time_last = 0.0  # exec time + wait time + final wait

        self.last_time_dict = {'token': 0.0, 'cmd': 0.0, 'exec_time': 0.0}

        # Sum of request and output token counts.
        self.total_tokens_used = 0

        self.req_std_avg_tokens_used = RunningStdevAvg()
        self.req_min_tokens_used = 100000000.0
        self.req_max_tokens_used = -100000000.0

        self.output_std_avg_tokens_used = RunningStdevAvg()
        self.output_min_tokens_used = 100000000.0
        self.output_max_tokens_used = -100000000.0

        self.total_num_inputs = 0  # number of lines

        self.total_num_cmds = 0  # number of commands (split line on logic ops and semicolon)
        self.std_avg_cmds_input = RunningStdevAvg()
        self.min_cmds_input = 100000000.0
        self.max_cmds_input = -100000000.0

        # Inter-arrival time between inputs.
        self.total_iat = 0.0
        self.std_avg_iat = RunningStdevAvg()
        self.min_iat = 100000000.0
        self.max_iat = -100000000.0

        self.total_exec_time = 0.0  # cowrie wait or llm resp time
        self.std_avg_exec_time = RunningStdevAvg()
        self.min_exec_time = 100000000.0
        self.max_exec_time = -100000000.0

        self.total_wait_time = 0.0
        self.std_avg_wait_time = RunningStdevAvg()
        self.min_wait_time = 100000000.0
        self.max_wait_time = -100000000.0

        # Raw per-event values, kept in arrival order.
        self.input_list = []
        self.iat_time_list = []
        self.exec_time_list = []
        self.wait_time_list = []
        self.req_token_amt_list = []
        self.output_token_amt_list = []

    def __repr__(self):
        """Return a short multi-line summary of the main session totals."""
        return (
            f'tokens_used: {self.total_tokens_used}\n'
            f'llm_resp_time: {self.total_exec_time}\n'
            f'num_inputs: {self.total_num_inputs}\n'
            f'session_time: {self.session_time}\n'
            f'input_list: {self.input_list}\n'
        )

# strptime layout of the timestamp that starts every log line.
_LOG_TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

# Seconds added to the final wait of a session ("18000 for timezone").
# NOTE(review): presumably the "TIME:" epochs in the log are 5 h behind the
# UTC line timestamps -- confirm.  The offset is also applied when
# last_input_epoch still holds the timestamp-derived start time.
_FINAL_WAIT_TZ_OFFSET = 18000


def _line_epoch(line):
    """Return the leading timestamp of *line* as seconds since 1970-01-01."""
    stamp = dt.datetime.strptime(line.split(' ')[0], _LOG_TIMESTAMP_FORMAT)
    return (stamp - dt.datetime(1970, 1, 1)).total_seconds()


def _text_after(tokens, marker):
    """Return the space-joined tokens after *marker*, with every '<del>' + newline removed.

    Raises ValueError when *marker* is not one of *tokens*.
    """
    return ' '.join(tokens[tokens.index(marker) + 1:]).replace('<del>\n', "")


def _close_session(entry, line):
    """Record the end of a session from its logout / connection-lost *line*."""
    epoch_time = _line_epoch(line)
    entry.end_time = epoch_time
    entry.session_time = entry.end_time - entry.start_time

    # Time between the last recorded input epoch and the end of the session.
    last_wait = epoch_time - entry.last_input_epoch + _FINAL_WAIT_TZ_OFFSET

    # Computed before total_wait_time absorbs the final wait.
    entry.session_time_last = entry.total_exec_time + entry.total_wait_time + last_wait

    entry.total_wait_time += last_wait
    entry.std_avg_wait_time.update(last_wait)
    entry.max_wait_time = max(entry.max_wait_time, last_wait)
    entry.min_wait_time = min(entry.min_wait_time, last_wait)
    entry.wait_time_list.append(last_wait)


def cowrie_log(filtered_log, log_data, d1: str):
    """Fill ``log_data`` with per-session statistics parsed from ``filtered_log``.

    Args:
        filtered_log: mapping of session id -> list of raw log lines for that
            session, in file order.
        log_data: mapping of session id -> ``LogData``; missing sessions are
            created here and the mapping is updated in place.
        d1: date label stored (stripped) on every newly created session.

    Each line is matched against a fixed sequence of substrings and only the
    first matching rule is applied.  Processing of a session stops at its
    first 'logging out' or 'Connection lost after' line.  The module-level
    ``old_ver_switch`` flag selects the 'CMD:' marker instead of 'INPUT_CMD:'
    for command lines.

    Raises:
        ValueError: if a timestamp or numeric field cannot be parsed, or an
            expected marker token is missing from a matched line.
        IndexError: if a login-attempt line does not have the expected
            bracket / quote layout.
    """
    for session, lines in filtered_log.items():

        if session not in log_data:  # init class for storing data for each session
            entry = log_data[session] = LogData()
            entry.date = d1.strip()

            entry.id = session.strip()
            # Everything before the trailing "_<proto>_<number>" part of the id.
            entry.name = '_'.join(session.split('_')[:-2]).strip()

            epoch_time = _line_epoch(lines[0])
            entry.start_time = epoch_time
            entry.end_time = epoch_time
            entry.last_input_epoch = epoch_time

        entry = log_data[session]

        for line in lines:

            if 'Remote SSH version: ' in line:
                entry.client_ver = line[line.find('version:'):].split(': ')[-1].rstrip('\n')

            elif 'SSH client hassh fingerprint: ' in line:
                entry.client_ver_hash = line[line.find('fingerprint:'):].split(': ')[-1].rstrip('\n')

            elif 'TERM_SIZE:' in line:
                entry.window_size = line[line.find('TERM_SIZE:'):].split(':')[-1].replace('<del>\n', "")

            elif 'public key attempt for user' in line:
                entry.pub_key = line[line.find('fingerprint'):].split(' ')[-1].rstrip('\n')

            elif ' login attempt [' in line:
                # Credentials are the 1st and 2nd single-quoted items inside
                # the second '[...]' group of the line.
                quoted = line.split('[')[2].split(']')[0].split("'")
                entry.user = quoted[1]
                entry.password = quoted[3]

            elif 'logging out' in line or 'Connection lost after' in line:
                # 'logging out' should trigger before connection lost;
                # 'Connection lost after' should only trigger if there was no
                # graceful logout.  Either one ends the session.
                _close_session(entry, line)
                break  # end of session

            elif 'INPUT_CMD:' in line or (old_ver_switch and 'CMD' in line):
                entry.total_num_inputs += 1
                line_split = line.split(' ')
                marker = 'CMD:' if old_ver_switch else 'INPUT_CMD:'
                curr_cmd_line = _text_after(line_split, marker)
                # One input line may chain several commands with ;, && or ||.
                cmd_list = curr_cmd_line.replace('&&', ';').replace('||', ';').split(';')
                num_cmds = len(cmd_list)

                entry.total_num_cmds += num_cmds
                entry.std_avg_cmds_input.update(num_cmds)
                entry.max_cmds_input = max(entry.max_cmds_input, num_cmds)
                entry.min_cmds_input = min(entry.min_cmds_input, num_cmds)
                entry.input_list.append(curr_cmd_line)

            elif 'CMD_IAT_TIME:' in line:
                curr_iat = float(_text_after(line.split(' '), 'CMD_IAT_TIME:'))

                entry.total_iat += curr_iat
                entry.std_avg_iat.update(curr_iat)
                entry.max_iat = max(entry.max_iat, curr_iat)
                entry.min_iat = min(entry.min_iat, curr_iat)
                entry.iat_time_list.append(curr_iat)

            elif 'CMD_WAIT_TIME:' in line:
                line_split = line.split(' ')
                curr_wait = float(_text_after(line_split, 'CMD_WAIT_TIME:'))

                entry.total_wait_time += curr_wait
                entry.std_avg_wait_time.update(curr_wait)
                entry.max_wait_time = max(entry.max_wait_time, curr_wait)
                entry.min_wait_time = min(entry.min_wait_time, curr_wait)
                entry.wait_time_list.append(curr_wait)

                # used for final session time last calc
                entry.last_input_epoch = float(line_split[line_split.index('TIME:') + 1])

            elif 'CMD_EXEC_TIME:' in line:
                line_split = line.split(' ')
                # Only the first token after the marker is the duration.
                curr_exec = float(_text_after(line_split, 'CMD_EXEC_TIME:').split(' ')[0])

                entry.total_exec_time += curr_exec
                entry.std_avg_exec_time.update(curr_exec)
                entry.max_exec_time = max(entry.max_exec_time, curr_exec)
                entry.min_exec_time = min(entry.min_exec_time, curr_exec)
                entry.exec_time_list.append(curr_exec)

                # used for final session time last calc
                entry.last_input_epoch = float(line_split[line_split.index('TIME:') + 1])

            elif 'REQ_TOKEN_USE:' in line:
                curr_tok = float(_text_after(line.split(' '), 'REQ_TOKEN_USE:'))

                entry.total_tokens_used += curr_tok
                entry.req_std_avg_tokens_used.update(curr_tok)
                entry.req_max_tokens_used = max(entry.req_max_tokens_used, curr_tok)
                entry.req_min_tokens_used = min(entry.req_min_tokens_used, curr_tok)
                entry.req_token_amt_list.append(curr_tok)

            elif 'OUTPUT_TOKEN_USE:' in line:
                curr_tok = float(_text_after(line.split(' '), 'OUTPUT_TOKEN_USE:'))

                entry.total_tokens_used += curr_tok
                entry.output_std_avg_tokens_used.update(curr_tok)
                entry.output_max_tokens_used = max(entry.output_max_tokens_used, curr_tok)
                entry.output_min_tokens_used = min(entry.output_min_tokens_used, curr_tok)
                entry.output_token_amt_list.append(curr_tok)


# Backend labels (not referenced by the other code visible in this file).
backends = [
    'basic',
    'llmv1',
    'llmv2',
    'llmv3_0.2',
    'llmv3_0.5',
    'llmv3_0.9',
]

# Sub-directories of gpt_cowrie_logs/ that main() scans; main() replaces this
# list with ['hih'] when the script is given exactly one argument.
honeydirs = [
    'gpt1',
    'gpt2',
    'gpt3',
    'gpt4',
    'gpt5',
    'gpt6',
]

# When True, command lines are read from the 'CMD:' marker instead of
# 'INPUT_CMD:' and no 'BACKEND:' tag is required; set by main().
old_ver_switch = False

# Column names of the output CSV, in the order _session_row() emits values.
_CSV_HEADER = [
    'session_id',
    'backend',
    'date',
    'start_time',
    'end_time',
    'session_time_on_logout',
    'session_time_on_last',
    'username',
    'password',
    'public_key',
    'SSH version',
    'SSH version hash',
    'window size',
    'total_inputs',
    'total_cmds',
    'avg_cmds_input',
    'std_cmds_input',
    'min_cmds_input',
    'max_cmds_inputs',
    'total_iat',
    'avg_iat',
    'std_iat',
    'min_iat',
    'max_iat',
    'honey_exec_time',
    'avg_honey_exec_time',
    'std_honey_exec_time',
    'min_honey_exec_time',
    'max_honey_exec_time',
    'honey_wait_time',
    'avg_honey_wait_time',
    'std_avg_honey_wait_time',
    'min_honey_wait_time',
    'max_honey_wait_time',
    'total_tokens_used',
    'req_avg_tokens_used',
    'req_std_tokens_used',
    'req_min_tokens_used',
    'req_max_tokens_used',
    'output_avg_tokens_used',
    'output_std_tokens_used',
    'output_min_tokens_used',
    'output_max_tokens_used',
    'input_list',
    'iat_list',
    'wait_list',
    'exec_list',
    'req_token_list',
    'output_token_list',
]


def _session_row(entry):
    """Return the CSV row for one session, ordered like ``_CSV_HEADER``."""
    return [
        entry.id,
        entry.name,
        entry.date,
        entry.start_time,
        entry.end_time,
        entry.session_time,
        entry.session_time_last,
        entry.user,
        entry.password,
        entry.pub_key,
        entry.client_ver,
        entry.client_ver_hash,
        entry.window_size,
        entry.total_num_inputs,
        entry.total_num_cmds,
        entry.std_avg_cmds_input.current_mean(),
        entry.std_avg_cmds_input.current_stddev(),
        entry.min_cmds_input,
        entry.max_cmds_input,
        entry.total_iat,
        entry.std_avg_iat.current_mean(),
        entry.std_avg_iat.current_stddev(),
        entry.min_iat,
        entry.max_iat,
        entry.total_exec_time,
        entry.std_avg_exec_time.current_mean(),
        entry.std_avg_exec_time.current_stddev(),
        entry.min_exec_time,
        entry.max_exec_time,
        entry.total_wait_time,
        entry.std_avg_wait_time.current_mean(),
        entry.std_avg_wait_time.current_stddev(),
        entry.min_wait_time,
        entry.max_wait_time,
        entry.total_tokens_used,
        entry.req_std_avg_tokens_used.current_mean(),
        entry.req_std_avg_tokens_used.current_stddev(),
        entry.req_min_tokens_used,
        entry.req_max_tokens_used,
        entry.output_std_avg_tokens_used.current_mean(),
        entry.output_std_avg_tokens_used.current_stddev(),
        entry.output_min_tokens_used,
        entry.output_max_tokens_used,
        entry.input_list,
        entry.iat_time_list,
        entry.wait_time_list,
        entry.exec_time_list,
        entry.req_token_amt_list,
        entry.output_token_amt_list,
    ]


def main():
    """Parse every selected log file and write one CSV row per session.

    With exactly one command-line argument the script switches to the
    old-format mode: only the 'hih' directory is read, ``old_ver_switch`` is
    set, and output goes to data.csv instead of data_all.csv.

    Log files whose last dot-separated name component sorts before
    '2024-04-22' are skipped; a file ending in '.log' is dated today.
    Sessions without any recorded input are dropped.

    NOTE(review): the directory listing uses a path relative to the current
    working directory while files are opened (and the CSV written) under an
    absolute path, so the two only agree when the script is started from the
    extract directory -- confirm this is intended.
    """
    global honeydirs
    global old_ver_switch
    writefile = "data_all.csv"

    if len(sys.argv) == 2:
        writefile = "data.csv"
        honeydirs = ['hih']
        old_ver_switch = True

    data = []
    for honeydir in honeydirs:
        for logfile in os.listdir(f'gpt_cowrie_logs/{honeydir}'):
            d1 = logfile.split('.')[-1].strip()
            if d1 == 'log':
                d1 = dt.datetime.now().strftime("%Y-%m-%d")
            if d1 < '2024-04-22':
                continue

            log_data = {}
            filtered_log = {}

            # The context manager closes the handle once the file is read.
            with open(f"/mnt/c/Users/jbrag/Desktop/cowrie_log_extract/gpt_cowrie_logs/{honeydir}/{logfile}", 'r') as handle:
                for line in handle:
                    if not (r1.search(line) and (r2.search(line) or old_ver_switch)):
                        continue

                    backend = 'hih'
                    if not old_ver_switch:
                        backend = ''.join(r2.findall(line)).split(':')[1]
                    session_id = ''.join(r3.findall(line)).split(',')[1]
                    # Session ids look like "<backend>_<SSH|TEL>_<number>".
                    if 'SSH' in line:
                        session_id = '{}_SSH_{}'.format(backend, session_id)
                    else:
                        session_id = '{}_TEL_{}'.format(backend, session_id)

                    filtered_log.setdefault(session_id, []).append(line)

            cowrie_log(filtered_log, log_data, d1)

            # Keep only sessions in which at least one input was recorded.
            log_data = {
                key: obj for key, obj in log_data.items()
                if obj.total_num_inputs != 0
            }
            print(f"{honeydir}/{logfile} read")
            data.append(log_data)

    with open(f'/mnt/c/Users/jbrag/Desktop/cowrie_log_extract/gpt_cowrie_logs/{writefile}', 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(_CSV_HEADER)
        for sessions in data:
            for entry in sessions.values():
                writer.writerow(_session_row(entry))


if __name__ == "__main__":
    main()
